import gc
import torch
import torch.nn as nn
import torch.nn.functional as F
from Utils import TokenList, TokenIndex, fileIO
from random import shuffle
    
class GraphTokenizer(object):
    """Word-level tokenizer over questions and knowledge-graph (KG) triples.

    The vocabulary is every space-separated question token and every KG
    element seen in ``data``, followed by the two structural markers
    ``'THEN'`` (closes a triple in ``encode``) and ``'END'`` (closes a graph).

    Side effect: ``__init__`` increments the module-level global ``kgc_size``
    by a per-datapoint token count.
    """

    def __init__(self,data):
        """Build the vocabulary from ``data``.

        Args:
            data: iterable of datapoints where ``datapoint[0][0]`` is the
                question string and ``datapoint[0][1]['KG']`` is a list of
                triples (each an iterable of tokens).
        """
        global kgc_size

        self.tokens = TokenList()
        self.token_index = TokenIndex()
        self.n_tokens = 0

        for datapoint in data:
            x = datapoint[0]
            question = x[0]
            KG = x[1]['KG']
            # Flatten the triples (linear time, accepts list or tuple triples).
            KG_tokens = [token for triple in KG for token in triple]
            datapoint_tokens = question.split(' ')+KG_tokens
            n_datapoint_tokens = len(datapoint_tokens)
            # Over-counts the markers encode() adds (one 'THEN' per triple
            # plus one 'END') when triples are non-empty; only feeds kgc_size.
            n_THEN_END_tokens = len(KG_tokens)+2
            kgc_size += n_datapoint_tokens + n_THEN_END_tokens
            self.tokens += datapoint_tokens

        special_tokens = ['THEN', 'END']
        # dict.fromkeys de-duplicates while keeping first-seen order, so the
        # token -> index mapping is reproducible across runs (set() iteration
        # order for strings varies with PYTHONHASHSEED). The markers are
        # filtered out of the data tokens so each appears exactly once, last.
        self.tokens = [token for token in dict.fromkeys(self.tokens)
                       if token not in special_tokens] + special_tokens
        self.n_tokens = len(self.tokens)

        for n, token in enumerate(self.tokens):
            self.token_index.add(token,n)

    def encode(self,inp):
        """Map a question string or a list of KG triples to token indices.

        A string is split on single spaces. Anything else is treated as a
        list of triples and flattened with ``'THEN'`` after every triple and
        a final ``'END'``. Each token is looked up with ``token_index.get``;
        what that returns for an unseen token depends on TokenIndex, which
        is not visible here.
        """
        if isinstance(inp, str):
            q_tokens = inp.split(' ')
            return [self.token_index.get(token) for token in q_tokens]
        G_tokens = [token for triple in inp for token in (*triple, 'THEN')]
        G_tokens.append('END')
        return [self.token_index.get(token) for token in G_tokens]

    def decode(self,G_token_indices):
        """Map token indices back to their tokens."""
        return [self.tokens[token_idx] for token_idx in G_token_indices]
    
class GraphDataOps(object):
    """Helpers that turn raw datapoints into graph-model training pairs."""

    @staticmethod
    def process_datapoint(datapoint,gtk=None):
        """Expand one datapoint into (prefix, next-token) pairs.

        The question and its knowledge graph are encoded with ``gtk`` and
        concatenated into one sequence; for every position ``k >= 1`` the
        pair is ``(sequence[:k], sequence[k])``.

        Args:
            datapoint: ``datapoint[0][0]`` is the question string and
                ``datapoint[0][1]['KG']`` the list of triples.
            gtk: tokenizer exposing ``encode``; required in practice, since
                it is called unconditionally despite the ``None`` default.

        Returns:
            ``(X, Y)``: the list of prefixes and the list of their next tokens.
        """
        sample = datapoint[0]
        question, graph = sample[0], sample[1]['KG']
        sequence = gtk.encode(question) + gtk.encode(graph)
        positions = range(1, len(sequence))
        prefixes = [sequence[:k] for k in positions]
        next_tokens = [sequence[k] for k in positions]
        return prefixes, next_tokens
    
class GraphGen(nn.Module):
    """Next-token model over graph-tokenizer indices.

    Sizes are read from module globals when the model is constructed:
    ``kgn_tokens`` (vocabulary), ``kgc_size`` (positional-embedding rows),
    ``kge_size`` (embedding width), ``kgh_size`` (hidden width) and
    ``kgn_layers`` (number of hidden layers). ``train``/``evaluate`` also read
    the globals ``CLEVRER_data``, ``eval_data`` and ``gtk``.

    NOTE(review): after the embedding sum every layer is an ``nn.Linear``
    applied to each position independently, and only the last position's
    output is kept. A prediction therefore depends only on the last token of
    the prefix and the prefix length; earlier tokens never influence it.
    """

    def __init__(self):

        super().__init__()
        self.embeddings = nn.Embedding(kgn_tokens,kge_size)
        self.pos_embeddings = nn.Embedding(kgc_size,kge_size)
        self.ffn_i = nn.Linear(kge_size,kgh_size)
        self.ffn_H = nn.ModuleList([nn.Linear(kgh_size,kgh_size) for _ in range(kgn_layers)])
        self.ffn_o = nn.Linear(kgh_size,kge_size)
        self.head = nn.Linear(kge_size,kgn_tokens)

    def forward(self,X):
        """Return next-token logits for each prefix in ``X``.

        Args:
            X: non-empty list of token-index lists (prefixes, possibly of
                different lengths, each no longer than ``kgc_size``).

        Returns:
            Tensor of shape ``(len(X), kgn_tokens)``; row ``i`` holds the
            logits computed at the last position of ``X[i]``.
        """
        # Build index tensors where the parameters live so a moved model works.
        device = self.embeddings.weight.device
        logits = []
        for x in X:
            e_x = self.embeddings(torch.tensor(x, device=device))
            e_x = e_x + self.pos_embeddings(torch.arange(len(x), device=device))
            e_x = F.leaky_relu(self.ffn_i(e_x))
            for ffn_h in self.ffn_H:
                e_x = F.leaky_relu(ffn_h(e_x))
            e_x = F.leaky_relu(self.ffn_o(e_x))
            e_x = self.head(e_x)
            logits.append(e_x[-1])

        return torch.stack(logits)

    def train(self,epochs = 10):
        """Train on the global ``CLEVRER_data`` with AdamW, one datapoint per step.

        After every optimizer step the training loss is printed and
        ``evaluate()`` is run and printed.

        ``nn.Module.eval()`` and a parent module's ``train(mode)`` call this
        method with a bool. That case is forwarded to ``nn.Module.train``, so
        the standard mode switch keeps working and returns ``self``.

        Args:
            epochs: number of passes over ``CLEVRER_data``.
        """
        if isinstance(epochs, bool):
            return super().train(epochs)

        opt = torch.optim.AdamW(self.parameters())
        CE = nn.CrossEntropyLoss()

        for epoch in range(epochs):
            shuffle(CLEVRER_data)  # shuffles the module-level list in place
            for datapoint in CLEVRER_data:
                X, Y = GraphDataOps.process_datapoint(datapoint,gtk)
                if not X:
                    # Fewer than two tokens: no (prefix, next-token) pair to learn.
                    continue
                logits = self(X)
                # Class-index targets: same loss as one-hot probability targets.
                targets = torch.tensor(Y, device=logits.device)
                loss = CE(logits,targets)
                loss.backward()
                opt.step()
                opt.zero_grad()
                eval_loss, eval_acc = self.evaluate()
                print (loss.item())
                print ('eval metrics (loss,acc)', eval_loss,eval_acc)

    def evaluate(self, n_datapoints=2):
        """Compute loss and next-token accuracy on the global ``eval_data``.

        Args:
            n_datapoints: how many leading entries of ``eval_data`` to score
                (defaults to 2).

        Returns:
            ``(loss, accuracy)``. ``loss`` is a 0-dim tensor holding the mean
            of the per-datapoint cross-entropies. ``accuracy`` is the fraction
            of positions whose argmax logit equals the target. Both are NaN
            when no datapoint yields a prediction pair.
        """
        CE = nn.CrossEntropyLoss()
        losses = []
        true_counter = 0
        compare_counter = 0

        with torch.no_grad():
            for datapoint in eval_data[:n_datapoints]:
                X, Y = GraphDataOps.process_datapoint(datapoint,gtk)
                if not X:
                    continue
                logits = self(X)
                targets = torch.tensor(Y, device=logits.device)
                losses.append(CE(logits,targets))
                # forward() scores every prefix independently, so row n of
                # `logits` already is the prediction for X[n].
                preds = logits.argmax(dim=-1)
                true_counter += int((preds == targets).sum())
                compare_counter += len(Y)

        if not losses:
            return torch.tensor(float('nan')), float('nan')
        return torch.stack(losses).mean(), true_counter/float(compare_counter)
    
class TextTokenizer(object):
    """Word-level tokenizer over questions and their programs.

    The vocabulary is every space-separated question token plus every
    program token seen in ``data``.

    Side effect: ``__init__`` increments the module-level global ``c_size``
    by a per-datapoint token count.
    """

    def __init__(self,data):
        """Build the vocabulary from ``data``.

        Args:
            data: iterable of datapoints where ``datapoint[0][0]`` is the
                question string and ``datapoint[1]`` is the program, a list
                of tokens.
        """
        global c_size

        self.tokens = TokenList()
        self.token_index = TokenIndex()
        self.n_tokens = 0

        for datapoint in data:
            x = datapoint[0]
            question = x[0]
            prg = datapoint[1]
            datapoint_tokens = question.split(' ')+prg
            n_datapoint_tokens = len(datapoint_tokens)
            # Counted toward c_size even though no 'END' token is added to
            # this vocabulary.
            n_END_tokens = 1
            c_size += n_datapoint_tokens + n_END_tokens
            self.tokens += datapoint_tokens

        # dict.fromkeys de-duplicates while keeping first-seen order, so the
        # token -> index mapping is reproducible across runs (set() iteration
        # order for strings varies with PYTHONHASHSEED).
        self.tokens = list(dict.fromkeys(self.tokens))
        self.n_tokens = len(self.tokens)

        for n, token in enumerate(self.tokens):
            self.token_index.add(token,n)

    def encode(self,inp):
        """Split ``inp`` on single spaces and map each token via ``token_index.get``."""
        tokens = inp.split(' ')
        return [self.token_index.get(token) for token in tokens]

    def decode(self,token_indices):
        """Map token indices back to their tokens."""
        return [self.tokens[token_idx] for token_idx in token_indices]
    
class DataOps(object):
    """Helpers that turn raw datapoints into text-model training pairs."""

    @staticmethod
    def process_datapoint(datapoint,tk=None):
        """Expand one datapoint into (prefix, next-token) pairs.

        The question and its program are encoded with ``tk`` and concatenated
        into one sequence; for every position ``k >= 1`` the pair is
        ``(sequence[:k], sequence[k])``.

        Args:
            datapoint: ``datapoint[0][0]`` is the question string and
                ``datapoint[1]`` the program, a list of tokens.
            tk: tokenizer exposing ``encode``; required in practice despite
                the ``None`` default.

        Returns:
            ``(X, Y)``: the list of prefixes and the list of their next tokens.

        NOTE(review): program tokens are joined with spaces and re-split by
        ``tk.encode``. A program token that itself contains a space would be
        split into pieces TextTokenizer never indexed; confirm programs are
        space-free.
        """
        question = datapoint[0][0]
        prg = datapoint[1]
        q_encoding = tk.encode(question)
        # ' '.join([]) is '' and ''.split(' ') is [''], so without this guard
        # an empty program would contribute a spurious '' token.
        prg_encoding = tk.encode(' '.join(prg)) if prg else []
        datapoint_encoding = q_encoding + prg_encoding
        positions = range(1, len(datapoint_encoding))
        X = [datapoint_encoding[:k] for k in positions]
        Y = [datapoint_encoding[k] for k in positions]
        return X, Y
    
class TextGen(nn.Module):
    """Next-token model over text-tokenizer indices.

    Sizes are read from module globals when the model is constructed:
    ``n_tokens`` (vocabulary), ``c_size`` (positional-embedding rows),
    ``e_size`` (embedding width), ``h_size`` (hidden width) and ``n_layers``
    (number of hidden layers). ``train``/``evaluate`` also read the globals
    ``CLEVRER_data``, ``eval_data`` and ``tk``.

    NOTE(review): after the embedding sum every layer is an ``nn.Linear``
    applied to each position independently, and only the last position's
    output is kept. A prediction therefore depends only on the last token of
    the prefix and the prefix length; earlier tokens never influence it.
    """

    def __init__(self):

        super().__init__()
        self.embeddings = nn.Embedding(n_tokens,e_size)
        self.pos_embeddings = nn.Embedding(c_size,e_size)
        self.ffn_i = nn.Linear(e_size,h_size)
        self.ffn_H = nn.ModuleList([nn.Linear(h_size,h_size) for _ in range(n_layers)])
        self.ffn_o = nn.Linear(h_size,e_size)
        self.head = nn.Linear(e_size,n_tokens)

    def forward(self,X):
        """Return next-token logits for each prefix in ``X``.

        Args:
            X: non-empty list of token-index lists (prefixes, possibly of
                different lengths, each no longer than ``c_size``).

        Returns:
            Tensor of shape ``(len(X), n_tokens)``; row ``i`` holds the
            logits computed at the last position of ``X[i]``.
        """
        # Build index tensors where the parameters live so a moved model works.
        device = self.embeddings.weight.device
        logits = []
        for x in X:
            e_x = self.embeddings(torch.tensor(x, device=device))
            e_x = e_x + self.pos_embeddings(torch.arange(len(x), device=device))
            e_x = F.leaky_relu(self.ffn_i(e_x))
            for ffn_h in self.ffn_H:
                e_x = F.leaky_relu(ffn_h(e_x))
            e_x = F.leaky_relu(self.ffn_o(e_x))
            e_x = self.head(e_x)
            logits.append(e_x[-1])

        return torch.stack(logits)

    def train(self,epochs = 10):
        """Train on the global ``CLEVRER_data`` with AdamW, one datapoint per step.

        After every optimizer step the training loss is printed and
        ``evaluate()`` is run and printed.

        ``nn.Module.eval()`` and a parent module's ``train(mode)`` call this
        method with a bool. That case is forwarded to ``nn.Module.train``, so
        the standard mode switch keeps working and returns ``self``.

        Args:
            epochs: number of passes over ``CLEVRER_data``.
        """
        if isinstance(epochs, bool):
            return super().train(epochs)

        opt = torch.optim.AdamW(self.parameters())
        CE = nn.CrossEntropyLoss()

        for epoch in range(epochs):
            shuffle(CLEVRER_data)  # shuffles the module-level list in place
            for datapoint in CLEVRER_data:
                X, Y = DataOps.process_datapoint(datapoint,tk)
                if not X:
                    # Fewer than two tokens: no (prefix, next-token) pair to learn.
                    continue
                logits = self(X)
                # Class-index targets: same loss as one-hot probability targets.
                targets = torch.tensor(Y, device=logits.device)
                loss = CE(logits,targets)
                loss.backward()
                opt.step()
                opt.zero_grad()
                eval_loss, eval_acc = self.evaluate()
                print (loss.item())
                print ('eval metrics (loss,acc)', eval_loss,eval_acc)

    def evaluate(self, n_datapoints=2):
        """Compute loss and next-token accuracy on the global ``eval_data``.

        Args:
            n_datapoints: how many leading entries of ``eval_data`` to score
                (defaults to 2).

        Returns:
            ``(loss, accuracy)``. ``loss`` is a 0-dim tensor holding the mean
            of the per-datapoint cross-entropies. ``accuracy`` is the fraction
            of positions whose argmax logit equals the target. Both are NaN
            when no datapoint yields a prediction pair.
        """
        CE = nn.CrossEntropyLoss()
        losses = []
        true_counter = 0
        compare_counter = 0

        with torch.no_grad():
            for datapoint in eval_data[:n_datapoints]:
                X, Y = DataOps.process_datapoint(datapoint,tk)
                if not X:
                    continue
                logits = self(X)
                targets = torch.tensor(Y, device=logits.device)
                losses.append(CE(logits,targets))
                # forward() scores every prefix independently, so row n of
                # `logits` already is the prediction for X[n].
                preds = logits.argmax(dim=-1)
                true_counter += int((preds == targets).sum())
                compare_counter += len(Y)

        if not losses:
            return torch.tensor(float('nan')), float('nan')
        return torch.stack(losses).mean(), true_counter/float(compare_counter)
    
class KiGM(nn.Module):
    """Joint model holding a graph generator and a text generator.

    ``gen_modules[0]`` is a GraphGen and ``gen_modules[1]`` a TextGen. They
    are trained together on the sum of their next-token losses. ``train``
    reads the globals ``CLEVRER_data``, ``gtk`` and ``tk``.
    """

    def __init__(self):

        super().__init__()
        self.gen_modules = nn.ModuleList([GraphGen(),TextGen()])

    def forward(self,X,text=True):
        """Run the text generator when ``text`` is truthy, else the graph generator."""
        if text:
            return self.gen_modules[1](X)
        return self.gen_modules[0](X)

    def train(self,
              epochs = 10):
        """Jointly train both generators on the global ``CLEVRER_data``.

        Each step sums the graph and text cross-entropy losses for one
        datapoint. After every step the loss and the text generator's
        ``evaluate()`` metrics are printed.

        ``nn.Module.eval()`` and a parent module's ``train(mode)`` call this
        method with a bool. That case is forwarded to ``nn.Module.train``, so
        the standard mode switch keeps working and returns ``self``.

        Args:
            epochs: number of passes over ``CLEVRER_data``.
        """
        if isinstance(epochs, bool):
            return super().train(epochs)

        opt = torch.optim.AdamW(self.parameters())
        CE = nn.CrossEntropyLoss()

        for epoch in range(epochs):
            shuffle(CLEVRER_data)  # shuffles the module-level list in place
            for datapoint in CLEVRER_data:
                A, B = GraphDataOps.process_datapoint(datapoint,gtk)
                X, Y = DataOps.process_datapoint(datapoint,tk)
                if not A or not X:
                    # Fewer than two tokens on one side: nothing to predict.
                    continue
                logits_a = self(A, text=False)
                logits_x = self(X)
                # Class-index targets: same loss as one-hot probability targets.
                targets_a = torch.tensor(B, device=logits_a.device)
                targets_x = torch.tensor(Y, device=logits_x.device)
                loss = CE(logits_a,targets_a)+CE(logits_x,targets_x)
                loss.backward()
                opt.step()
                opt.zero_grad()
                eval_loss,eval_acc = self.gen_modules[1].evaluate()
                print (loss.item())
                print ('eval metrics (loss,acc)', eval_loss,eval_acc)

if __name__ == '__main__':

    gc.collect()

    # Graph-model hyper-parameters. GraphTokenizer and GraphGen read these as
    # module globals, so they must exist before those classes are used.
    kgc_size = 0        # incremented by GraphTokenizer
    kgn_tokens = None   # filled in from the graph tokenizer below
    kge_size = 96
    kgh_size = 96
    kgn_layers = 2

    # Text-model hyper-parameters, read the same way by TextTokenizer/TextGen.
    c_size = 0          # incremented by TextTokenizer
    n_tokens = None     # filled in from the text tokenizer below
    e_size = 96
    h_size = 96
    n_layers = 2

    epochs = 100

    # NOTE(review): if fileIO.read_pickle unpickles the file, only load
    # trusted data -- unpickling can execute arbitrary code.
    CLEVRER_data = fileIO.read_pickle('CLEVRER_data.pkl')
    # A single split index gives disjoint train/eval sets that together cover
    # every row. A negative slice like data[-int(0.2*n):] turns into data[-0:]
    # (the whole list) when n < 5.
    n_train = int(0.8*len(CLEVRER_data))
    train_data = CLEVRER_data[:n_train]
    eval_data = CLEVRER_data[n_train:]
    CLEVRER_data = train_data
    gtk = GraphTokenizer(CLEVRER_data)
    tk = TextTokenizer(CLEVRER_data)
    n_tokens = tk.n_tokens
    kgn_tokens = gtk.n_tokens

    model = KiGM()
    model.train(epochs=epochs)